# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
import numpy as np

from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .callback import _InternalCallbackParam, RunContext, _build_callbacks
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
from ..nn.metrics import Loss
from ..nn.wrap import WithLossCell, WithEvalCell, DataWrapper
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp


class Model:
    """
    High-Level API for Training or Testing.

    `Model` groups layers into an object with training and inference features.

    Args:
        network (Cell): The training or testing network.
        loss_fn (Cell): Objective function. If `loss_fn` is None, the
                             network should contain the logic of loss and gradient calculation, and
                             the logic of parallel if needed. Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
                        training and testing, e.g. {'accuracy', 'recall'}. Default: None.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
                             `eval_network`. Default: None.
        eval_indexes (list): When `eval_network` is defined and `eval_indexes` is None, all outputs of
                             `eval_network` are passed to the metrics; otherwise `eval_indexes` must contain three
                             elements, representing the positions of the loss value, predict value and label. The
                             loss value is passed to the `Loss` metric, while the predict value and label are
                             passed to the other metrics. Default: None.
        amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
            precision training. Supports [O0, O2]. Default: "O0".

            - O0: Do not change the network's precision.
            - O2: Cast the network to float16, keep batchnorm running in float32, and use dynamic loss scaling.

        loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise
            the loss is scaled by the given LossScaleManager. If set, it overwrites the level setting.
            It is a keyword argument, e.g. `loss_scale_manager=None`.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*222*222, 3) # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
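        >>> # A hedged sketch of evaluation with a user-defined eval network; WithEvalCell
        >>> # outputs (loss, prediction, label), matching eval_indexes=[0, 1, 2].
        >>> eval_net = nn.WithEvalCell(net, loss)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'accuracy'},
        >>>               eval_network=eval_net, eval_indexes=[0, 1, 2])
        >>> model.eval(get_dataset())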
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._loss_scale_manager = None
        self._loss_scale_manager_set = False
        self._check_kwargs(kwargs)
        if 'loss_scale_manager' in kwargs:
            self._loss_scale_manager = kwargs['loss_scale_manager']
            self._loss_scale_manager_set = True
        self._amp_level = amp_level
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()

        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)

    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager']:
                raise ValueError(f"Unsupported arg '{arg}'")

    def _build_train_network(self):
        """Build train network"""
        network = self._network
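        # With an optimizer, amp.build_train_network wraps the network, loss function and
        # optimizer (plus loss scaling, when configured) into one training cell; with only
        # a loss_fn, the network is wrapped by WithLossCell and no update step is attached.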
        if self._optimizer:
            if self._loss_scale_manager_set:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  loss_scale_manager=self._loss_scale_manager)
            else:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level)
        elif self._loss_fn:
            network = WithLossCell(network, self._loss_fn)
        # TODO: may need to check the case where loss_fn is not None but optimizer is None
        return network

    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network for evaluation."""
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
                raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, "
                                 "its length must be three. But got {}".format(eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None when `eval_network` is not defined.")
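            # WithEvalCell outputs (loss, prediction, label), hence the default indexes below.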
            self._eval_network = WithEvalCell(self._network, self._loss_fn)
            self._eval_indexes = [0, 1, 2]

    def _clear_metrics(self):
        """Clear metrics local values."""
        for metric in self._metric_fns.values():
            metric.clear()

    def _update_metrics(self, outputs):
        """Update metrics local values."""
        if not isinstance(outputs, tuple):
            raise ValueError("The `outputs` must be a tuple.")

        if self._eval_indexes is not None and len(outputs) < 3:
            raise ValueError("The length of `outputs` must be greater than or equal to 3, "
                             "but got {}".format(len(outputs)))

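        # With eval_indexes set, the Loss metric consumes only the loss output, while all
        # other metrics consume the (prediction, label) pair.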
        for metric in self._metric_fns.values():
            if self._eval_indexes is None:
                metric.update(*outputs)
            else:
                if isinstance(metric, Loss):
                    metric.update(outputs[self._eval_indexes[0]])
                else:
                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])

    def _get_metrics(self):
        """Get metrics local values."""
        metrics = dict()
        for key, value in self._metric_fns.items():
            metrics[key] = value.eval()
        return metrics

    def _get_scaling_sens(self):
        """Get the scaling sens (loss scale coefficient)."""
        scaling_sens = 1
        if self._loss_scale_manager is not None:
            scaling_sens = self._loss_scale_manager.get_loss_scale()
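        # In data parallel mode gradients are summed across devices, so the scaling
        # coefficient is divided by the device number to keep the effective scale unchanged.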
        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
            scaling_sens /= self._device_number
        return scaling_sens

    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be
                                     returned and passed to the network. Otherwise, a tuple (data, label) will
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
        """
        epoch = check_int_positive(epoch)
        self._train_network.set_train()

        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        # build callback list
        list_callback = _build_callbacks(callbacks)
        cb_params = _InternalCallbackParam()
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch
        cb_params.batch_num = train_dataset.get_dataset_size()
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
        cb_params.parallel_mode = self._parallel_mode
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = list_callback

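        # Dataset sink mode only takes effect in graph mode; PyNative always uses the
        # non-sink training path.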
        if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE:
            self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
        else:
            self._train_process(epoch, train_dataset, list_callback, cb_params)

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
        """
        Training process. The data would be passed to network through dataset channel.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            list_callback (_ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        # remove later to deal with loop sink
        need_wrap = False
        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink"):
            need_wrap = True

        dataset_helper = DatasetHelper(train_dataset)
        # remove later to deal with loop sink
        if need_wrap:
            self._train_network = DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
                                              train_dataset.__ME_INITED__)
            cb_params.train_network = self._train_network
            self._train_network.set_train()

        cb_params.cur_step_num = 0
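        # loop_size is the number of steps executed on device per host-side iteration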
        loop_size = dataset_helper.loop_size()
        run_context = RunContext(cb_params)
        _callback_wrapper(list_callback, run_context, "begin")

        # used to stop training early, e.g. by a stop-at-time or stop-at-step callback
        should_stop = False
        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
            _callback_wrapper(list_callback, run_context, "epoch_begin")

            # with data sink, dataset_helper iterates only once; otherwise it iterates epoch_size times
            for inputs in dataset_helper:
                cb_params.cur_step_num += loop_size
                _callback_wrapper(list_callback, run_context, "step_begin")
                outputs = self._train_network(*inputs)
                cb_params.net_outputs = outputs
                _callback_wrapper(list_callback, run_context, "step_end")

            _callback_wrapper(list_callback, run_context, "epoch_end")
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        _callback_wrapper(list_callback, run_context, "end")

    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
        """
        Training process. The data would be passed to network directly.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            list_callback (_ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
        cb_params.cur_step_num = 0
        run_context = RunContext(cb_params)
        _callback_wrapper(list_callback, run_context, "begin")
        # used to stop training early, e.g. by a stop-at-time or stop-at-step callback
        should_stop = False

        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1

            _callback_wrapper(list_callback, run_context, "epoch_begin")

            for next_element in dataset_helper:
                len_element = len(next_element)
                if self._loss_fn and len_element != 2:
                    raise ValueError("when loss_fn is not None, train_dataset should "
                                     "return two elements, but got {}".format(len_element))
                cb_params.cur_step_num += 1
                _callback_wrapper(list_callback, run_context, "step_begin")

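                # With drop_overflow_update enabled, the current loss scale is appended as an
                # extra input, and the wrapped network returns (loss, overflow flag, scaling
                # sens) so that the manager can adjust the scale after each step.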
                overflow = False
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    scaling_sens = self._get_scaling_sens()
                    next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)

                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    _, overflow, _ = outputs
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

                _callback_wrapper(list_callback, run_context, "step_end")
                should_stop = should_stop or run_context.get_stop_requested()
                if should_stop:
                    break

            train_dataset.reset()

            _callback_wrapper(list_callback, run_context, "epoch_end")
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        _callback_wrapper(list_callback, run_context, "end")

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training API where the iteration is controlled by the Python front-end.

        When configured to PyNative mode, the training is performed with dataset non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is True.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            callbacks (list): List of callback objects, which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
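            >>> # A hedged sketch of non-sink training, e.g. on targets where sink mode is
            >>> # unsupported (see the CPU note above):
            >>> model.train(2, dataset, dataset_sink_mode=False)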
        """
        repeat_count = train_dataset.get_repeat_count()
        if epoch != repeat_count and dataset_sink_mode is True:
            logger.warning(f"The epoch_size {epoch} is not the same as the dataset repeat_count {repeat_count}.")
        check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

        if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"):
            raise ValueError("CPU and GPU do not support loop sink, please set enable_loop_sink=False.")

        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode)

    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to network through dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
        _device_number_check(self._parallel_mode, self._device_number)

        run_context = RunContext(cb_params)

        # remove later to deal with loop sink
        need_wrap = False
        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink"):
            need_wrap = True

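        # force one step per device loop so metrics can be updated after every batch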
        valid_dataset.__loop_size__ = 1
        dataset_helper = DatasetHelper(valid_dataset)

        # remove later to deal with loop sink
        if need_wrap:
            self._eval_network = DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
                                             valid_dataset.__ME_INITED__)
            self._eval_network.set_train(mode=False)
            self._eval_network.phase = 'eval'

        list_callback.begin(run_context)

        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*inputs)

            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics

    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to network directly.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False)
        for next_element in dataset_helper:
            list_callback.step_begin(run_context)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)
        return metrics

    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Evaluation API where the iteration is controlled by the Python front-end.

        When configured to PyNative mode, the evaluation is performed with dataset non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is True.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (list): List of callback objects, which should be executed
                              while evaluating. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        check_bool(dataset_sink_mode)
        if not self._metric_fns:
            raise ValueError("The metrics can not be None or empty.")

        list_callback = _build_callbacks(callbacks)
        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0

        self._eval_network.set_train(mode=False)
        self._eval_network.phase = 'eval'

        self._clear_metrics()

        if dataset_sink_mode and context.get_context("mode") == context.GRAPH_MODE:
            return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
        return self._eval_process(valid_dataset, list_callback, cb_params)

    def predict(self, *predict_data):
        """
        Generates output predictions for the input samples.

        Data can be a single tensor, a list of tensors, or a tuple of tensors.

        Note:
            Batch data should be put together in one tensor.

        Args:
            predict_data (Tensor): The predict data, which can be a single tensor,
                                   or a list or tuple of tensors.

        Returns:
            Tensor, array(s) of predictions.

        Examples:
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mstype.float32)
            >>> model = Model(Net())
            >>> model.predict(input_data)
        """
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._network = _VirtualDatasetCell(self._network)
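            # (the wrap above lets (semi-)auto parallel treat the inputs as a virtual
            # dataset and slice them across devices)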

        self._network.set_train(False)
        check_input_data(*predict_data, data_class=Tensor)
        result = self._network(*predict_data)

        check_output_data(result)
        return result


__all__ = ["Model"]
